# -*- coding: utf-8 -*- 
# @Time : 2022/4/4 10:27 
# @Author : zzuxyj 
# @File : 13-nn-modelSaveAndLoadAndPretrain.py

"""
1.模型加载
2.模型预训练参数
"""
import torch
import torchvision

# 模型加载 保存方法一
model = torch.load("./model/vgg16_method01.pth")
print(model)

# 模型加载 保存方法二
# 加载torchvision提供的网络模型  , 并且设置预训练参数为false , 即从头开始训练，初始化参数是随机的，不能有很好的效果
vgg16 = torchvision.models.vgg16(pretrained=False)
modelParam = torch.load("./model/vgg16_method02.pth")
vgg16.load_state_dict(modelParam)
print(vgg16)


# 加载自己的模型的时候，确保能够找到自己的网络模型  from Model import *